from functools import partial
import math
import numpy as np
import tensorflow as tf

##### Define class to get alignments and attention

class BahdanauAttention(tf.keras.layers.Layer):

    """
    Additive (Bahdanau) attention layer whose call() returns alignments rather than a context vector.

    Heavily inspired by two sources:

    https://www.tensorflow.org/tutorials/text/nmt_with_attention
    https://github.com/tensorflow/addons/tree/v0.7.1/tensorflow_addons/seq2seq

    Note that in nmt_with_attention, the call function directly returns the attention (context_vector).
    The attention classes in tensorflow_addons instead return the alignments first and compute the
    attention outside of the class. Based on Kyubyong's and tensorflow_addons' code, the attention at
    timestep t (a_t) appears to be computed using the attention at timestep t-1 (a_t-1), so this class
    follows the tensorflow_addons design: call() returns alignments, and _compute_attention() turns
    alignments into a context vector.
    """

    def __init__(self,
                 units,
                 normalize = False,
                 **kwargs):

        """
        Args:
            units: Number of units in the dense projection layers and in the score vector `attention_v`.
            normalize: Whether to weight-normalize the score vector. If True, an extra scalar gain
                (`attention_g`) and a bias vector (`attention_b`) are created and used in the score.
            **kwargs: Other common arguments for Keras layer creation, forwarded to the base class.
        """

        super(BahdanauAttention, self).__init__(**kwargs)

        self.units = units
        self.normalize = normalize

        # Projections of the memory (values) and of the query into the shared `units` space.
        self.W1 = tf.keras.layers.Dense(self.units, name = "W_value", use_bias = False)
        self.W2 = tf.keras.layers.Dense(self.units, name = "W_query", use_bias = False)

    def build(self, input_shape):

        """
        Create the scoring weights.

        Args:
            input_shape: Forwarded to the base-class build(); not otherwise used here.
        """

        super(BahdanauAttention, self).build(input_shape)

        # Keyword arguments (rather than a positional name) keep this call valid across Keras
        # versions whose add_weight() signatures order `name` and `shape` differently.
        self.attention_v = self.add_weight(name="attention_v",
                                           shape=[self.units],
                                           dtype=self.dtype,
                                           initializer="glorot_uniform")

        if self.normalize:

            # Scalar gain g, initialized to sqrt(1 / units).
            self.attention_g = self.add_weight(name="attention_g",
                                               shape=(),
                                               dtype=self.dtype,
                                               initializer=tf.constant_initializer(
                                                   math.sqrt((1. / self.units))))

            # Per-unit bias added inside the tanh of the normalized score.
            self.attention_b = self.add_weight(name="attention_b",
                                               shape=[self.units],
                                               dtype=self.dtype,
                                               initializer=tf.zeros_initializer())


    def call(self,
             query,
             values,
             memory_mask = None):

        """
        Compute alignments of the query over every memory timestep.

        Args:
            query: Input tensor used to compute attention at the current timestep. In this code it is the
                cell output computed from the previous timestep's final output and previous attention.
                Presumably (batch, query_dim); it gets a time axis inserted at position 1 so it
                broadcasts against `values` - TODO confirm shape with callers.
            values: Memory computed beforehand, often the full-sequence (return_sequences = True) output
                of an encoder. Presumably (batch, memory_time, value_dim) - TODO confirm.
            memory_mask: Optional mask for the memory. True/nonzero marks positions to keep. Numeric
                masks are cast to bool. Masking is only applied if a mask is given.

        Returns:
            alignments: Softmax over axis 1 of the scores, i.e. how much attention each memory timestep gets.
            next_state: Next state value; the same tensor as `alignments`.
        """

        hidden_with_time_axis = tf.expand_dims(query, 1)

        if self.normalize:
            # Rescale v to have norm g (weight normalization) before scoring.
            normed_v = self.attention_g * self.attention_v * tf.math.rsqrt(tf.reduce_sum(tf.square(self.attention_v)))
            score = tf.reduce_sum(normed_v * tf.tanh(self.W1(values) + self.W2(hidden_with_time_axis) + self.attention_b),
                                  [2])
        else:
            score = tf.reduce_sum(self.attention_v * tf.tanh(self.W1(values) + self.W2(hidden_with_time_axis)),
                                  [2])

        ##### Masking process
        if memory_mask is not None:
            score = self._maybe_mask_score(score, memory_mask)

        alignments = tf.nn.softmax(score, axis=1) # alignments
        next_state = alignments # This is part of _calculate_attention function of attention mechanisms in tfa.seq2seq.attention_wrapper

        return alignments, next_state


    ##### Function to mask score
    def _maybe_mask_score(self, score, memory_mask):

        """
        Replace masked-out scores with the most negative finite value of the score dtype.

        A finite minimum (instead of -inf) keeps the following softmax free of NaNs even if an entire
        row is masked.

        Args:
            score: Score to be masked.
            memory_mask: Mask tensor used as the `tf.where` condition against `score`. True/nonzero keeps
                the score. Non-bool masks (e.g. 1./0. floats) are cast to bool.

        Returns:
            outputs: Output tensor with the mask applied.
        """

        # tf.where requires a boolean condition; the cast is a no-op for bool masks.
        memory_mask = tf.cast(memory_mask, tf.bool)

        score_mask_value = score.dtype.min
        score_mask_values = score_mask_value * tf.ones_like(score)

        outputs = tf.where(memory_mask, score, score_mask_values)

        return outputs


    ##### Function to compute attention
    def _compute_attention(self, alignments, values):

        """
        Turn alignments into a context vector by taking the alignment-weighted sum of `values`.

        Args:
            alignments: Alignment value computed by call().
            values: Memory computed beforehand, often the full-sequence (return_sequences = True) output.

        Returns:
            attention: Context vector (alignments @ values, with the singleton axis squeezed).
            alignments: The alignments given, returned unchanged.
        """

        expanded_alignments = tf.expand_dims(alignments, 1)
        context_vector = tf.matmul(expanded_alignments, values)
        context_vector = tf.squeeze(context_vector, [1])

        # No specific attention layer given
        attention = context_vector

        return attention, alignments




##### Define class to get Bahdanau Monotonic alignments and attention

class BahdanauMonotonicAttention(tf.keras.layers.Layer):

    """
    Bahdanau-style monotonic attention layer ("parallel" mode), adapted from tensorflow_addons.

    call() returns alignments rather than a context vector; _compute_attention() turns alignments into
    a context vector. The alignments at the current step depend on the previous step's alignments
    through the monotonic attention recurrence in _monotonic_attention().
    """

    def __init__(self,
                 units,
                 sigmoid_noise = 0.0,
                 normalize = False,
                 **kwargs):

        """
        Args:
            units: Number of units in the dense projection layers and in the score vector `attention_v`.
            sigmoid_noise: Scale of the standard-normal noise added to the scores before the sigmoid.
                Noise is only added when this is > 0. It is read at call time, so it can be changed
                after construction (e.g. set to 0.0 for inference).
            normalize: Whether to weight-normalize the score vector. If True, an extra scalar gain
                (`attention_g`) and a bias vector (`attention_b`) are created and used in the score.
            **kwargs: Other common arguments for Keras layer creation, forwarded to the base class.
        """

        # Initialize the Keras base layer before assigning attributes/sublayers, as Keras expects
        # (and as BahdanauAttention does), so they are tracked by a fully initialized Layer.
        super(BahdanauMonotonicAttention, self).__init__(**kwargs)

        self.units = units
        self.sigmoid_noise = sigmoid_noise
        self.normalize = normalize

        # Projections of the memory (values) and of the query into the shared `units` space.
        self.W1 = tf.keras.layers.Dense(self.units, name = "W_value", use_bias = False)
        self.W2 = tf.keras.layers.Dense(self.units, name = "W_query", use_bias = False)


    def probability_fn(self, score, previous_alignments):

        """
        Map scores to monotonic alignments using the layer's current `sigmoid_noise`.

        The noise level is looked up on every call instead of being bound once at construction. Binding
        it once would make later changes to `self.sigmoid_noise` silently ineffective.
        """

        return self._monotonic_probability_fn(score, previous_alignments, self.sigmoid_noise)


    def build(self, input_shape):

        """
        Create the scoring weights.

        Args:
            input_shape: Forwarded to the base-class build(); not otherwise used here.
        """

        # Keyword arguments (rather than a positional name) keep these calls valid across Keras
        # versions whose add_weight() signatures order `name` and `shape` differently.
        # Learned scalar bias added to every score, initialized to 0.
        self.attention_score_bias = self.add_weight(name="attention_score_bias",
                                                    shape=(),
                                                    dtype=self.dtype,
                                                    initializer=tf.constant_initializer(0.0))

        self.attention_v = self.add_weight(name="attention_v",
                                           shape=[self.units],
                                           dtype=self.dtype,
                                           initializer="glorot_uniform")

        if self.normalize:

            # Scalar gain g, initialized to sqrt(1 / units).
            self.attention_g = self.add_weight(name="attention_g",
                                               shape=(),
                                               dtype=self.dtype,
                                               initializer=tf.constant_initializer(math.sqrt((1. / self.units))))

            # Per-unit bias added inside the tanh of the normalized score.
            self.attention_b = self.add_weight(name="attention_b",
                                               shape=[self.units],
                                               dtype=self.dtype,
                                               initializer=tf.zeros_initializer())

        super(BahdanauMonotonicAttention, self).build(input_shape)


    def call(self,
             query,
             values,
             previous_alignments,
             memory_mask = None):

        """
        Compute monotonic alignments for the current decoding step.

        Args:
            query: Input tensor for the current step. Presumably (batch, query_dim); it gets a time axis
                inserted at position 1 so it broadcasts against `values` - TODO confirm with callers.
            values: Memory computed beforehand. Presumably (batch, memory_time, value_dim) - TODO confirm.
            previous_alignments: Alignments from the previous step; cumulatively summed along axis 1 by
                the monotonic recurrence. Presumably (batch, memory_time) - TODO confirm.
            memory_mask: Optional mask for the memory. True/nonzero marks positions to keep. Numeric
                masks are cast to bool.

        Returns:
            alignments: Monotonic attention distribution for the current step.
            next_attention_state: The same tensor as `alignments`.
        """

        hidden_with_time_axis = tf.expand_dims(query, 1) # Processed query

        if self.normalize:
            # Rescale v to have norm g (weight normalization) before scoring.
            normed_v = self.attention_g * self.attention_v * tf.math.rsqrt(tf.reduce_sum(tf.square(self.attention_v)))
            score = tf.reduce_sum(normed_v * tf.tanh(self.W1(values) + self.W2(hidden_with_time_axis) + self.attention_b),
                                  [2])
        else:
            score = tf.reduce_sum(self.attention_v * tf.tanh(self.W1(values) + self.W2(hidden_with_time_axis)),
                                  [2])

        ##### Masking process
        if memory_mask is not None:
            score = self._maybe_mask_score(score, memory_mask)

        score += self.attention_score_bias

        alignments = self.probability_fn(score, previous_alignments)
        next_attention_state = alignments

        return alignments, next_attention_state


    ##### Define function to mask score
    def _maybe_mask_score(self, score, memory_mask):

        """
        Replace masked-out scores with the most negative finite value of the score dtype.

        After the sigmoid in _monotonic_probability_fn, masked positions get a choosing probability
        of (numerically) 0.

        Args:
            score: Score to be masked.
            memory_mask: Mask tensor used as the `tf.where` condition against `score`. True/nonzero keeps
                the score. Non-bool masks are cast to bool.

        Returns:
            Score tensor with the mask applied.
        """

        # tf.where requires a boolean condition; the cast is a no-op for bool masks.
        memory_mask = tf.cast(memory_mask, tf.bool)

        score_mask_value = score.dtype.min
        score_mask_values = score_mask_value * tf.ones_like(score)

        return tf.where(memory_mask, score, score_mask_values)


    ##### Define function to get monotonic probability
    def _monotonic_probability_fn(self,
                                  score,
                                  previous_alignments,
                                  sigmoid_noise):
        """
        Attention probability function for monotonic attention, obtained from
        (https://github.com/tensorflow/addons/blob/1af92905ed03f05fcf6f4918783c3d151a8b8350/tensorflow_addons/seq2seq/attention_wrapper.py#L810).

        Args:
            score: Attention scores.
            previous_alignments: Alignments from the previous step.
            sigmoid_noise: Scale of the standard-normal noise added before the sigmoid (only if > 0).

        Returns:
            Monotonic alignments computed from sigmoid(score) and `previous_alignments`.
        """
        # Optionally add pre-sigmoid noise to the scores
        if sigmoid_noise > 0:
            noise = tf.random.normal(tf.shape(score), dtype=score.dtype)
            score += sigmoid_noise * noise
        # Compute "choosing" probabilities from the attention scores

        p_choose_i = tf.sigmoid(score) # probability of selecting i

        return self._monotonic_attention(p_choose_i, previous_alignments)


    ##### Define function to get monotonic attention
    def _monotonic_attention(self, p_choose_i, previous_alignments):
        """
        Compute monotonic attention, adapted from the tensorflow_addons code
        (https://github.com/tensorflow/addons/blob/1af92905ed03f05fcf6f4918783c3d151a8b8350/tensorflow_addons/seq2seq/attention_wrapper.py#L927).
        Only the "parallel" mode is implemented here.

        Args:
            p_choose_i: Per-position "choosing" probabilities.
            previous_alignments: Alignments from the previous step.

        Returns:
            attention: p_choose_i * exclusive_cumprod(1 - p_choose_i)
                * cumsum(previous_alignments / clipped exclusive_cumprod(1 - p_choose_i)), along axis 1.
        """
        # Force things to be tensors
        p_choose_i = tf.convert_to_tensor(p_choose_i, name="p_choose_i")
        previous_alignments = tf.convert_to_tensor(
            previous_alignments, name="previous_attention")

        # safe_cumprod computes cumprod in logspace with numeric checks
        cumprod_1mp_choose_i = self._safe_cumprod(
            1 - p_choose_i, axis=1, exclusive=True)

        # Compute recurrence relation solution
        attention = p_choose_i * cumprod_1mp_choose_i * tf.cumsum(
            previous_alignments /
            # Clip cumprod_1mp to avoid divide-by-zero
            tf.clip_by_value(cumprod_1mp_choose_i, 1e-10, 1.),
            axis=1)

        return attention

    ##### Define function to be used for monotonic attention
    def _safe_cumprod(self, x, *args, **kwargs):
        """
        Cumulative product computed as exp(cumsum(log(x))).

        `x` is clipped to [tiny, 1] before the log, where tiny is the smallest positive normal value of
        its dtype, so zeros do not produce -inf. Extra arguments (e.g. axis, exclusive) go to tf.cumsum.
        """

        x = tf.convert_to_tensor(x, name="x")
        tiny = np.finfo(x.dtype.as_numpy_dtype).tiny

        return tf.exp(
            tf.cumsum(
                tf.math.log(tf.clip_by_value(x, tiny, 1)), *args, **kwargs))

    def _compute_attention(self, alignments, values):
        """
        Turn alignments into a context vector by taking the alignment-weighted sum of `values`.

        Args:
            alignments: Alignments computed by call().
            values: Memory computed beforehand.

        Returns:
            attention: Context vector (alignments @ values, with the singleton axis squeezed).
            alignments: The alignments given, returned unchanged.
        """

        expanded_alignments = tf.expand_dims(alignments, 1)
        context_vector = tf.matmul(expanded_alignments, values)
        context_vector = tf.squeeze(context_vector, [1])

        # No specific attention layer given
        attention = context_vector

        return attention, alignments